{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "b8ba4873",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-21T07:02:33.566772Z",
     "start_time": "2024-05-21T07:02:33.561772Z"
    }
   },
   "outputs": [],
   "source": [
    "import gym\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from gym.envs.toy_text.frozen_lake import generate_random_map\n",
    "from IPython import display\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "1dd0863c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-21T06:33:36.290500Z",
     "start_time": "2024-05-21T06:33:36.282484Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([6, 8, 12])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "img=torch.randn(1,3,8,12)\n",
    "weight=torch.randn(6,3,3,3)\n",
    "bias=torch.randn(6)\n",
    "#funcitonal api\n",
    "output=F.conv2d(img,weight,bias,padding=1)\n",
    "output=output.squeeze()\n",
    "output.shape\n",
    "# plt.imshow(output)\n",
    "torch.utils.data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "0846f9e6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-21T07:02:34.682641Z",
     "start_time": "2024-05-21T07:02:34.674085Z"
    }
   },
   "outputs": [],
   "source": [
    "class GymHelper:\n",
    "    \"\"\"Draw a gym environment's rendered frames inline in the notebook.\n",
    "\n",
    "    Args:\n",
    "        env: environment whose ``render()`` returns an image array that ``imshow`` can draw\n",
    "            (the notebook creates it with ``render_mode=\"rgb_array\"``).\n",
    "        figsize: matplotlib figure size in inches.\n",
    "    \"\"\"\n",
    "    def __init__(self,env,figsize=(3,3)):\n",
    "        self.env=env\n",
    "        self.figsize=figsize\n",
    "        # Keep handles to our own figure and axes so that other pyplot calls\n",
    "        # cannot change which figure gets redrawn.\n",
    "        self.fig=plt.figure(figsize=figsize)\n",
    "        self.ax=self.fig.gca()\n",
    "        self.img=self.ax.imshow(env.render())\n",
    "    def render(self,title=None):\n",
    "        \"\"\"Redraw the current frame; ``title`` (if given) is shown together with this frame.\"\"\"\n",
    "        img_data=self.env.render()\n",
    "        self.img.set_data(img_data)\n",
    "        # Set the title before displaying, otherwise it only appears on the next frame.\n",
    "        if title:\n",
    "            self.ax.set_title(title)\n",
    "        display.display(self.fig)\n",
    "        # wait=True defers the clear until the next output arrives\n",
    "        display.clear_output(wait=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "3174d0ba",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-21T07:02:35.529784Z",
     "start_time": "2024-05-21T07:02:35.523256Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from tqdm import *\n",
    "import collections\n",
    "import time\n",
    "import random\n",
    "import sys\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "092cf3b5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-21T07:02:39.896778Z",
     "start_time": "2024-05-21T07:02:36.354240Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAR8AAAEnCAYAAACQfkeNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAjUklEQVR4nO3df3hT9d038PdJ86NpmsT+gMTYggUqCgWU4o8yxw8pVWdlPLsvZcMpe+TaDUorvYDLgW6DudkynPjjYcClMtw1H1fdhVU2saMKVLnRRywUWtjYvVmgLY0V16ZJmyZp8n3+UHKbtpSktDnn1Pfrus4fPfk0/QTom3O+53u+RxJCCBARxZlG7gaI6JuJ4UNEsmD4EJEsGD5EJAuGDxHJguFDRLJg+BCRLBg+RCQLhg8RyYLhQ2GvvfYaJk+eDKPRCEmSsHDhQkiSNKj3OnDgACRJwoEDB2L6vquvvhqFhYWD+pmxOH36NCRJwssvvzzsP4v6p5W7AVKGzz//HPfffz/uuOMObN26FQaDAQ6HA2vXrh3U+02fPh0ffvghJk2aNMSd0kjB8CEAwD/+8Q8EAgH88Ic/xOzZs8P7x4wZM6j3s1gsuOWWW4aqPRqBeNpF+NGPfoRbb70VALBo0SJIkoQ5c+Zgw4YNfU67LpwWVVZWYvr06TAajbj22mvxu9/9LqKuv9OuTz/9FN///vfhcDhgMBhgs9kwb9481NbW9unpUu8PAE6nE8uWLUNGRgb0ej2ysrLwi1/8Aj09PRF1586dw7333guz2Qyr1YpFixbB6XQO8k+LhgqPfAg/+9nPcNNNN2HFihUoLS3F3LlzYbFY8Prrr/dbf+zYMaxevRpr166FzWbDSy+9hKVLl2LChAmYNWvWRX/Od77zHQSDQWzatAljxozB+fPncejQIbS3t8f8/k6nEzfddBM0Gg1+/vOfY/z48fjwww/xq1/9CqdPn8bOnTsBAF6vF/n5+Th37hzKyspwzTXX4O2338aiRYuG5g+PBk8QCSH2798vAIg//elP4X3r168Xvf+JjB07ViQmJoozZ86E93m9XpGamiqWLVvW5/32798vhBDi/PnzAoB49tlnB+wj2vdftmyZSE5OjqgTQojf/OY3AoA4ceKEEEKIbdu2CQDirbfeiqj78Y9/LACInTt3DtgPDR+edlHMrr/++oixoMTERFxzzTU4c+bMRb8nNTUV48ePx1NPPYXNmzfj6NGjCIVCg37/v/zlL5g7dy4cDgd6enrC25133gkAqK6uBgDs378fZrMZCxYsiPgZixcvjv2D05Bi+FDM0tLS+uwzGAzwer0X/R5JkvDee+/h9ttvx6ZNmzB9+nSMGjUKjzzyCNxud8zv/9lnn+HPf/4zdDpdxDZ58mQAwPnz5wEAX3zxBWw2W5/3s9vt0X1YGjYc86G4GTt2LHbs2AHgy6trr7/+OjZs2AC/34/t27fH9F7p6emYOnUqnnzyyX5fdzgcAL4Mso8//rjP6xxwlh/Dh2RxzTXX4Kc//Sl27dqFI0eOxPz9hYWF2LNnD8aPH4+UlJSL1s2dOxevv/46du/eHXHq9eqrrw6qbxo6DB+Ki+PHj6OoqAj33HMPsrOzodfrsW/fPhw/fnxQExmfeOIJVFVVYebMmXjkkUcwceJEdHd34/Tp09izZw+2b9+OjIwMPPDAA3jmmWfwwAMP4Mknn0R2djb27NmDv/71r8PwKSkWDB+KC7vdjvHjx2Pr1q1obGyEJEkYN24cnn76aRQXF8f8fldeeSU++eQT/PKXv8RTTz2FpqYmmM1mZGVl4Y477ggfDSUlJWHfvn1YuXIl1q5dC0mSUFBQgPLycsycOXOoPybFQBKCT68govjj1S4ikgXDh4hkwfAhIlnIGj5bt25FVlYWEhMTkZubiw8++EDOdogojmQLn9deew0lJSV4/PHHcfToUXz729/GnXfeibNnz8rVEhHFkWxXu26+
+WZMnz4d27ZtC++77rrrsHDhQpSVlcnREhHFkSzzfPx+P2pqavpMLisoKMChQ4cu+f2hUAjnzp2D2Wwe9DKfRDT0hBBwu91wOBzQaAY+sZIlfM6fP49gMNjnhj+bzdbvPTc+nw8+ny/8dXNzM5fnJFKwxsZGZGRkDFgj6wzn3kctQoh+j2TKysrwi1/8os/+xsZGWCyWYeuPiGLT0dGBzMxMmM3mS9bKEj7p6elISEjoc5TT2tra7/IH69atw6pVq8JfX/iAFouF4UOkQNEMh8hytUuv1yM3NxdVVVUR+y/cKNibwWAIBw0Dh2hkkO20a9WqVbj//vsxY8YM5OXl4YUXXsDZs2exfPlyuVoiojiSLXwWLVqEL774Ak888QRaWlqQk5ODPXv2YOzYsXK1RERxpMq72js6OmC1WuFyuXgKRqQgsfxu8t4uIpIFw4eIZMHwISJZMHyISBYMHyKSBcOHiGTB8CEiWTB8iEgWDB8ikgXDh4hkwfAhIlkwfIhIFgwfIpIFw4eIZMHwISJZMHyISBYMHyKSBcOHiGTB8CEiWTB8iEgWDB8ikgXDh4hkwfAhIlkwfIhIFgwfIpIFw4eIZMHwISJZMHyISBYMHyKSBcOHiGTB8CEiWTB8iEgWDB8ikgXDh4hkwfAhIlnEHD7vv/8+7r77bjgcDkiShDfffDPidSEENmzYAIfDAaPRiDlz5uDEiRMRNT6fD8XFxUhPT4fJZMKCBQvQ1NR0WR+EiNQl5vDp7OzEtGnTsGXLln5f37RpEzZv3owtW7bg8OHDsNvtmD9/Ptxud7impKQEFRUVKC8vx8GDB+HxeFBYWIhgMDj4T0JE6iIuAwBRUVER/joUCgm73S42btwY3tfd3S2sVqvYvn27EEKI9vZ2odPpRHl5ebimublZaDQaUVlZGdXPdblcAoBwuVyX0z4RDbFYfjeHdMynoaEBTqcTBQUF4X0GgwGzZ8/GoUOHAAA1NTUIBAIRNQ6HAzk5OeEaIhr5tEP5Zk6nEwBgs9ki9ttsNpw5cyZco9frkZKS0qfmwvf35vP54PP5wl93dHQMZdtEJINhudolSVLE10KIPvt6G6imrKwMVqs1vGVmZg5Zr0QkjyENH7vdDgB9jmBaW1vDR0N2ux1+vx9tbW0Xrelt3bp1cLlc4a2xsXEo2yYiGQxp+GRlZcFut6Oqqiq8z+/3o7q6GjNnzgQA5ObmQqfTRdS0tLSgvr4+XNObwWCAxWKJ2IhI3WIe8/F4PPjnP/8Z/rqhoQG1tbVITU3FmDFjUFJSgtLSUmRnZyM7OxulpaVISkrC4sWLAQBWqxVLly7F6tWrkZaWhtTUVKxZswZTpkxBfn7+0H0yIlK2WC+l7d+/XwDosy1ZskQI8eXl9vXr1wu73S4MBoOYNWuWqKuri3gPr9crioqKRGpqqjAajaKwsFCcPXs26h54qZ1ImWL53ZSEEELG7BuUjo4OWK1WuFwunoIRKUgsv5u8t4uIZMHwISJZMHyISBYMHyKSBcOHiGTB8CEiWTB8iEgWDB8ikgXDh4hkMaTr+cTbsWPHkJycLHcbRPQVj8cTda2qw+fzzz9HV1eX3G0Q0Vc6OzujrlV1+OTn5/PeLiIFiWWVUY75EJEsGD5EJAuGDxHJguFDRLJg+BCRLBg+RCQLhg8RyYLhQ0SyYPgQkSwYPkQkC4YPEcmC4UNEsmD4EJEsGD5EJAuGDxHJguFDRLJg+BCRLBg+RCQLhg8RyYLhQ0SyYPgQkSwYPkQkC4YPEckipvApKyvDjTfeCLPZjNGjR2PhwoU4depURI0QAhs2bIDD4YDRaMScOXNw4sSJiBqfz4fi4mKkp6fDZDJhwYIFaGpquvxPQ0SqEVP4VFdXY8WKFfjoo49QVVWFnp4eFBQURDylcNOmTdi8eTO2bNmCw4cPw263Y/78+XC73eGakpISVFRUoLy8HAcPHoTH40FhYSGCweDQfTIiUjZxGVpbWwUAUV1dLYQQIhQK
CbvdLjZu3Biu6e7uFlarVWzfvl0IIUR7e7vQ6XSivLw8XNPc3Cw0Go2orKyM6ue6XC4BQLhcrstpn4iGWCy/m5c15uNyuQAAqampAICGhgY4nU4UFBSEawwGA2bPno1Dhw4BAGpqahAIBCJqHA4HcnJywjVENPIN+lntQgisWrUKt956K3JycgAATqcTAGCz2SJqbTYbzpw5E67R6/VISUnpU3Ph+3vz+Xzw+Xzhr2N5HjQRKdOgj3yKiopw/Phx/PGPf+zzmiRJEV8LIfrs622gmrKyMlit1vCWmZk52LaJSCEGFT7FxcXYvXs39u/fj4yMjPB+u90OAH2OYFpbW8NHQ3a7HX6/H21tbRet6W3dunVwuVzhrbGxcTBtE5GCxBQ+QggUFRXhjTfewL59+5CVlRXxelZWFux2O6qqqsL7/H4/qqurMXPmTABAbm4udDpdRE1LSwvq6+vDNb0ZDAZYLJaIjYjULaYxnxUrVuDVV1/FW2+9BbPZHD7CsVqtMBqNkCQJJSUlKC0tRXZ2NrKzs1FaWoqkpCQsXrw4XLt06VKsXr0aaWlpSE1NxZo1azBlyhTk5+cP/SckImWK5TIagH63nTt3hmtCoZBYv369sNvtwmAwiFmzZom6urqI9/F6vaKoqEikpqYKo9EoCgsLxdmzZ6Pug5faiZQplt9NSQgh5Iu+weno6IDVaoXL5eIpGJGCxPK7yXu7iEgWDB8ikgXDh4hkwfAhIlkwfIhIFgwfIpIFw4eIZMHwISJZMHyISBYMHyKSxaAXEyO6XEIEEQy6EAp5IUlaJCRYIEmJl1z7iUYGhg/FnRACgUAzPv98O1yuP8Pvb4RGY0JS0gyMHv0wzOa5kCT+0xzp+DdMcSWEgM/33zh9+n+js/NDfLkwAhAMtsHlaoLHsx8ORylGjfpPBtAIxzEfiqtg0IWzZ4vR2XkIF4Kn9+vNzY/B5aqEChdcoBgwfCiuXK634Xa/N2BNKOTCZ589jVCoc8A6UjeGD8WVx3MQwKUfDtnZdQQdgfZh74fkw/AhRfKGQnjf4750IakWw4fipr2nB9Vdop+Rnr7cSESd148gx31GLIYPxU23ENgRKEAHBl5eUwCoxO3Y5U5AVygUn+Yo7hg+FDejtFpMsFyPF/Cf6Iah3xoB4ChuwP/FD9EcCKKtpye+TVLcMHwobhIkCVkGIyrwH/g/KEYrRiEIKfwYFB/0qMYsPIGf4zxG4d89PdjndvOS+wjFWVwUV3dbrXiiRY/Xxb34EHm4BR8hA03ohAnHMA3HMA1eJAH48prYpz4fBADecDHyMHworkZptcjU6fCpX+AsxuIsxg5Y/2Z7Ox6/8koYeL/XiMPTLooru06HqUlJUde7QiG0By89L4jUh+FDcaWRJNxqMkVd7wwEUNPVNYwdkVwYPhR3ecnJUY/h+IXAqe5uDjqPQAwfirvxBgOuMfR/qb0/u9vbo5qYSOrC8KG4G6XVYoxeH3X9uUAAbRz3GXEYPhR3EoC7r7gi6vp/+Xw4xnGfEYfhQ7KYnJgY9TyPIIC/cdxnxGH4UNxJkoRpSUkYF8O4z1vt7cPXEMmC4UOySElIQGYM4z5NgQDO8z6vEYXhQ7KQ8OWtFtH6e3c3Gvz+4WuI4o7hQ7KQJAnXJSYiMYbbJv7l8w1jRxRvDB+SzXSTCTadLqpaAeBtl4uDziMIw4dkY9ZocEVCQtT1n/p86ODiYiNGTOGzbds2TJ06FRaLBRaLBXl5eXjnnXfCrwshsGHDBjgcDhiNRsyZMwcnTpyIeA+fz4fi4mKkp6fDZDJhwYIFaGpqGppPQ6qilyTck5ISdf2Rri6c47jPiBFT+GRkZGDjxo345JNP8Mknn+C2227Dd7/73XDAbNq0CZs3b8aWLVtw+PBh2O12zJ8/H273/ywEXlJSgoqKCpSXl+PgwYPweDwoLCxEkDNYv5HG6vVRz/fpEQLHvN5h7YfiRxKXeRKdmpqKp556Cg8++CAcDgdKSkrw
k5/8BMCXRzk2mw2//vWvsWzZMrhcLowaNQp/+MMfsGjRIgDAuXPnkJmZiT179uD222+P6md2dHTAarXC5XLBYhl4PWBSNmcggGknT6I1ysvoy9LTsW3MGD7PXaFi+d0c9JhPMBhEeXk5Ojs7kZeXh4aGBjidThQUFIRrDAYDZs+ejUOHDgEAampqEAgEImocDgdycnLCNf3x+Xzo6OiI2GhkSNZoMDaG+T7HvF6O+4wQMYdPXV0dkpOTYTAYsHz5clRUVGDSpElwOp0AAJvNFlFvs9nCrzmdTuj1eqT0Os//ek1/ysrKYLVaw1tmZmasbZNCJWk0mGs2R11/wutFB0/RR4SYw2fixImora3FRx99hIceeghLlizByZMnw6/3PhwWQlzyEPlSNevWrYPL5QpvjY2NsbZNCqWRJFyflARdlKdRASFwkuM+I0LM4aPX6zFhwgTMmDEDZWVlmDZtGp577jnY7XYA6HME09raGj4astvt8Pv9aGtru2hNfwwGQ/gK24WNRo48kynqyYbdQuADj4fzfUaAy57nI4SAz+dDVlYW7HY7qqqqwq/5/X5UV1dj5syZAIDc3FzodLqImpaWFtTX14dr6JsnVavFtBjWda7t6oKP4aN6MT294rHHHsOdd96JzMxMuN1ulJeX48CBA6isrIQkSSgpKUFpaSmys7ORnZ2N0tJSJCUlYfHixQAAq9WKpUuXYvXq1UhLS0NqairWrFmDKVOmID8/f1g+ICmfWaNBTmIiDno8UdV/0tUFTzCIRA3nyKpZTOHz2Wef4f7770dLSwusViumTp2KyspKzJ8/HwDw6KOPwuv14uGHH0ZbWxtuvvlm7N27F+avDSg+88wz0Gq1uPfee+H1ejFv3jy8/PLLSIhhpiuNLJIkId9iwYvnzyOaoWR3KISjXi/mR3lrBinTZc/zkQPn+Yw8p7q7Mf1vf4v62ey/ycjAqtGjOd9HYeIyz4doKDl0OuTGMO7zbkcHAur7f5O+huFDipCs0WBCDCsb/qO7G36Gj6oxfEgRJEnC3VZr1M/z+ncwiL93dw9rTzS8GD6kGNmJiTBFeQWrPRhEndfL+T4qxvAhxRin12NyYmLU9dUeDx8mqGIMH1IMo0aDNG30sz+OcLKhqjF8SFEWp6ZGXXvG5+N9XirG8CHFkCQJY/R66KOcu9MRCuG0389xH5Vi+JCiTE9KwsQYxn3e6HWTMqkHw4cUxRjj4mKn/X5ONlQphg8pigTgf11xRdT1R7u6cJLzfVSJ4UOKIkkSJhgMSI5yvk+3EHyMskoxfEhxbkhKiulhgpVc01uVGD6kOEaNJqabTE96vfBzUXnVYfiQ4iTgy6VV+xL9boe7utAcCMSxQxoKMS0mRhQPkiThJpMJZo0G7q+OaLQIYDpqcBv2YSzOohMmfIJcvId8tPeMRksggKwY7oon+TF8SJGuS0yEJSEB7lAIJniwHNvxXbyFJPzPjOZZeB/fxW78Gj9BRfto5JlMXFxMRXjaRYpk0mgwKzkZWgTwIHbgHvwpIngAQAOBCfgnfoZfQuv7G+f7qAzDhxRJJ0mYbDQiBydwL/4E7UVWd5YAjMFZjHE/D1cP5/uoCU+7SJEkScI8sxk+zT4khgYOFQnAjeIgTOLfAKK/Skby4pEPKVaWwYC0BCmq1Q2DIoh63uGuKgwfUqyUhAR8Kzk5qgXDuoXAR52dvMNdRRg+pFh6jQZtxjvQg4FvNBUADuNG/D2QBN5ooR4MH1K00ZZZqJTuRvAi/1QFgM9gw8v4Efa4fegMRvPYQVIChg8pmikhCTukFXgHd8DX6whIAGjGVSjDOpzEJHSGQvDwNgvV4NUuUjS7TofxSQ5s9KzFf+FbyMd7yEAjumDCEdyAPbgLZzAGgAYdwSD+y+PBohiWYiX5MHxI0UwaDew6HbqRhCoU4D3kQ/pqCDoEDQQk4KvrYX4h0BQIQAjBmc4qwNMuUjRJknBL+CZTCSEkIAgtgtBCQAP0uhBf
09nJQWeVYPiQ4t1oMkX9JNOPu7rQw8vtqsDwIcVL12qRkpAQVW1XKIRWLq+hCgwfUrwxen3UT7RwBgL4f52dw9wRDQWGDyleoiRhdJRPMhUAvggGOdNZBRg+pHiSJOF2iyXq+r+0t4OzfZSP4UOqMD6GVQqbAwEEeeSjeAwfUoVxBgOujPLUqyUQwKc+3zB3RJfrssKnrKwMkiShpKQkvE8IgQ0bNsDhcMBoNGLOnDk4ceJExPf5fD4UFxcjPT0dJpMJCxYsQFNT0+W0QiNchl4f9eN0Wnt60MBnuCveoMPn8OHDeOGFFzB16tSI/Zs2bcLmzZuxZcsWHD58GHa7HfPnz4fb7Q7XlJSUoKKiAuXl5Th48CA8Hg8KCwsR5E2BdBF6SYI5ysvtAHjkowKDCh+Px4P77rsPL774IlJSUsL7hRB49tln8fjjj+N73/secnJy8Pvf/x5dXV149dVXAQAulws7duzA008/jfz8fNxwww145ZVXUFdXh3fffXdoPhWNOBKA/4jhMcpvtrcPVys0RAYVPitWrMBdd92F/Pz8iP0NDQ1wOp0oKCgI7zMYDJg9ezYOHToEAKipqUEgEIiocTgcyMnJCdf05vP50NHREbHRN489ytMuAPCEQvDxtEvRYg6f8vJyHDlyBGVlZX1eczqdAACbzRax32azhV9zOp3Q6/URR0y9a3orKyuD1WoNb5mZmbG2TSonSRJyjEaMinLQud7rxT+6uaC8ksUUPo2NjVi5ciVeeeUVJA4w47T3HcXR3GU8UM26devgcrnCW2NjYyxt0wgxzmDAFVGO+3SGQnBxDFHRYgqfmpoatLa2Ijc3F1qtFlqtFtXV1Xj++eeh1WrDRzy9j2BaW1vDr9ntdvj9frS1tV20pjeDwQCLxRKx0TePVpIwLsr5PgJAtcczvA3RZYkpfObNm4e6ujrU1taGtxkzZuC+++5DbW0txo0bB7vdjqqqqvD3+P1+VFdXY+bMmQCA3Nxc6HS6iJqWlhbU19eHa4j6owWQbzZHXX+kq4uX2xUspsXEzGYzcnJyIvaZTCakpaWF95eUlKC0tBTZ2dnIzs5GaWkpkpKSsHjxYgCA1WrF0qVLsXr1aqSlpSE1NRVr1qzBlClT+gxgE32dJElI02qhAaK6fcIZCMAdCsESwyV6ip8hX8nw0UcfhdfrxcMPP4y2tjbcfPPN2Lt3L8xf+x/rmWeegVarxb333guv14t58+bh5ZdfRgL/kdAlzDKbkarV4nzPpZcMq+nqQrPfD4vRGIfOKFaSUOFxaUdHB6xWK1wuF8d/vmFcwSAmnziB5ijW7NEC+GDiRNySnDz8jRGA2H43eW8XqUqiJCEvvKzqwEIAPuLaPorF8CFV0UsSJkV5GhUCcNzrRUh9B/ffCAwfUhVJkjDeYIh6sLLB50MXn+WlSAwfUp1vmUxI0kT3T/eo14s2TjZUJIYPqY45IQHJUV4Z9YdCOMcF5RWJ4UOqc0VCAm6N8gpWtxA44HZzsqECMXxIdXSSBFsMC8o3+/1c01mBGD6kOpIk4dvJyVEPOh/u6oKXg86Kw/AhVZpsNCIhyuex/4tXvBSJ4UOqlKLVIiPKxcW8oRD+m8uqKg7Dh1QpPSEBk6OcbOgJhXC4s5ODzgrD8CFV0koSrophWdXmQICDzgrD8CFVkiQJBRYLohv1Aao6OhDgkY+iMHxItTL1ekS7CEtbMAgPB50VheFDqpWp1yN7gLXEv+58Tw+OdXUNc0cUC4YPqVaaVgtHlOM+XaEQzvApporC8CHVSvhqWdVo1XZ1gdGjHAwfUrV7ej3/bSAfeDy84qUgDB9StViOfNyhENqiWPuZ4oPhQ6p2XWIixkf5LK8zPh+Oe70c91EIhg+p2iitFqlRru3TA+BzHvkoBsOHVE0DYFpSUtT1lR0dw9cMxYThQ6o3O4ZH43zq83HQWSEYPqRqkiQhXauFPsrlNRr9frRwWVVFYPiQ6t1oMkV1k6mEL28wbfL7
h78puqQhf1wyUbyZNZp+n2aRAMCg0SBDp8MkoxHXG434VnIypsYwRkTDh+FDqpfw1R3u//T5YE5IwDi9HtcnJWF6UhJmJCVhjF4PS0IC9JIEKcrTMxp+DB9SPQ2AH6WlYVZyMqYYjbDpdEjUaJAAMGwUjOFDqidJEqYmJfF0SmU44ExEsmD4EJEsGD5EJAuGDxHJguFDRLJg+BCRLBg+RCQLhg8RyYLhQ0SyUOUM5wvLYHZwYSgiRbnwOxnNUrWqDB+32w0AyMzMlLkTIuqP2+2G1WodsEYSKlxNOxQK4dSpU5g0aRIaGxthsVjkbilqHR0dyMzMZN9xwr7jSwgBt9sNh8MBTT/LnHydKo98NBoNrrrqKgCAxWJR1V/OBew7vth3/FzqiOcCDjgTkSwYPkQkC9WGj8FgwPr162GI8oFxSsG+44t9K5cqB5yJSP1Ue+RDROrG8CEiWTB8iEgWDB8ikoUqw2fr1q3IyspCYmIicnNz8cEHH8jaz/vvv4+7774bDocDkiThzTffjHhdCIENGzbA4XDAaDRizpw5OHHiRESNz+dDcXEx0tPTYTKZsGDBAjQ1NQ1r32VlZbjxxhthNpsxevRoLFy4EKdOnVJ879u2bcPUqVPDE/Dy8vLwzjvvKLrn/pSVlUGSJJSUlKiu9yEhVKa8vFzodDrx4osvipMnT4qVK1cKk8kkzpw5I1tPe/bsEY8//rjYtWuXACAqKioiXt+4caMwm81i165doq6uTixatEhceeWVoqOjI1yzfPlycdVVV4mqqipx5MgRMXfuXDFt2jTR09MzbH3ffvvtYufOnaK+vl7U1taKu+66S4wZM0Z4PB5F9757927x9ttvi1OnTolTp06Jxx57TOh0OlFfX6/Ynnv7+OOPxdVXXy2mTp0qVq5cGd6vht6HiurC56abbhLLly+P2HfttdeKtWvXytRRpN7hEwqFhN1uFxs3bgzv6+7uFlarVWzfvl0IIUR7e7vQ6XSivLw8XNPc3Cw0Go2orKyMW++tra0CgKiurlZd7ykpKeKll15SRc9ut1tkZ2eLqqoqMXv27HD4qKH3oaSq0y6/34+amhoUFBRE7C8oKMChQ4dk6mpgDQ0NcDqdET0bDAbMnj073HNNTQ0CgUBEjcPhQE5OTlw/l8vlAgCkpqaqpvdgMIjy8nJ0dnYiLy9PFT2vWLECd911F/Lz8yP2q6H3oaSqG0vPnz+PYDAIm80Wsd9ms8HpdMrU1cAu9NVfz2fOnAnX6PV6pKSk9KmJ1+cSQmDVqlW49dZbkZOTE+7rQh+9+5K797q6OuTl5aG7uxvJycmoqKjApEmTwr+ASuwZAMrLy3HkyBEcPny4z2tK/vMeDqoKnwt6P39bCKH4Z3IPpud4fq6ioiIcP34cBw8e7POaEnufOHEiamtr0d7ejl27dmHJkiWorq4Ov67EnhsbG7Fy5Urs3bsXiYmJF61TYu/DQVWnXenp6UhISOiT8K2trX3+t1AKu90OAAP2bLfb4ff70dbWdtGa4VRcXIzdu3dj//79yMjICO9Xcu96vR4TJkzAjBkzUFZWhmnTpuG5555TdM81NTVobW1Fbm4utFottFotqqur8fzzz0Or1YZ/thJ7Hw6qCh+9Xo/c3FxUVVVF7K+qqsLMmTNl6mpgWVlZsNvtET37/X5UV1eHe87NzYVOp4uoaWlpQX19/bB+LiEEioqK8MYbb2Dfvn3IyspSTe+9CSHg8/kU3fO8efNQV1eH2tra8DZjxgzcd999qK2txbhx4xTb+7CQZ5x78C5cat+xY4c4efKkKCkpESaTSZw+fVq2ntxutzh69Kg4evSoACA2b94sjh49Gr78v3HjRmG1WsUbb7wh6urqxA9+8IN+L59mZGSId999Vxw5ckTcdtttw3759KGHHhJWq1UcOHBAtLS0hLeurq5wjRJ7X7dunXj//fdFQ0ODOH78uHjssceERqMRe/fuVWzPF/P1q11q6/1yqS58hBDi
t7/9rRg7dqzQ6/Vi+vTp4UvDctm/f78A0GdbsmSJEOLLS6jr168XdrtdGAwGMWvWLFFXVxfxHl6vVxQVFYnU1FRhNBpFYWGhOHv27LD23V/PAMTOnTvDNUrs/cEHHwz//Y8aNUrMmzcvHDxK7flieoePmnq/XFxSg4hkoaoxHyIaORg+RCQLhg8RyYLhQ0SyYPgQkSwYPkQkC4YPEcmC4UNEsmD4EJEsGD5EJAuGDxHJguFDRLL4/wajx+BPqdhMAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Smoke test: drive Acrobot with random actions for at most 20 steps and draw each frame inline.\n",
    "env=gym.make(\"Acrobot-v1\",render_mode=\"rgb_array\")\n",
    "env.reset()\n",
    "gym_helper=GymHelper(env)\n",
    "for i in range(20):\n",
    "    gym_helper.render(title=str(i))\n",
    "    # random policy: sample an action from the action space\n",
    "    action=env.action_space.sample()\n",
    "    observation,reward,terminated,truncated,info=env.step(action)\n",
    "    # stop on either end-of-episode flag returned by step()\n",
    "    done=terminated or truncated\n",
    "    if done:\n",
    "        break\n",
    "gym_helper.render(\"finished\")\n",
    "env.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e430793",
   "metadata": {},
   "source": [
    "## MBPO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "6dd9ee5a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-21T07:04:29.143447Z",
     "start_time": "2024-05-21T07:04:29.138436Z"
    }
   },
   "outputs": [],
   "source": [
    "#定义策略网络\n",
    "class PolicyModel(nn.Module):\n",
    "    def __init__(self,input_dim,output_dim):\n",
    "        super(PolicyModel,self).__init__()\n",
    "        #使用全连接层构建一个简单的神经网络，共享部分网络层\n",
    "        self.fc=nn.Sequential([\n",
    "            nn.Linear(input_dim,128),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(128,128),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(128,output_dim),\n",
    "            nn.Softmax(dim=1)\n",
    "        ])\n",
    "    def forward(self,x):\n",
    "        action_prob=self.fc(x)\n",
    "        return action_prob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "8d13f809",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-21T07:04:37.277020Z",
     "start_time": "2024-05-21T07:04:37.271507Z"
    }
   },
   "outputs": [],
   "source": [
    "#定义Q网络模型\n",
    "class QvalueModel(nn.Module):\n",
    "    def __init__(self,input_dim,output_dim):\n",
    "        super(QvalueModel,self).__init__()\n",
    "        #使用全连接层构建一个简单的神经网络，共享部分网络层\n",
    "        self.fc=nn.Sequential([\n",
    "            nn.Linear(input_dim,128),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(128,128),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(128,output_dim),\n",
    "        ])\n",
    "    def forward(self,x):\n",
    "        value=self.fc(x)\n",
    "        return value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9f6eb595",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-17T07:50:44.003821Z",
     "start_time": "2024-05-17T07:50:43.998311Z"
    }
   },
   "outputs": [],
   "source": [
    "class ReplayBuffer:\n",
    "    def __init__(self,max_size):\n",
    "        self.max_size=max_size\n",
    "        self.buffer=collections.deque(maxlen=self.max_size)\n",
    "    def add(self,state,action,reward,next_state,done):\n",
    "        experience=(state,action,reward,next_state,done)\n",
    "        self.buffer.append(experience)\n",
    "    def return_all(self):\n",
    "        batch=list(self.buffer)\n",
    "        state,action,reward,next_state,done=zip(*batch)\n",
    "        return np.array(state),np.array(action),np.array(reward),np.array(next_state),np.array(done)\n",
    "    def sample(self,batch_size):\n",
    "        if batch_size>len(self.buffer):\n",
    "            return self.return_all()\n",
    "        batch=random.sample(self.buffer,batch_size)\n",
    "        state,action,reward,next_state,done=zip(*batch)\n",
    "        return state,action,reward,next_state,done\n",
    "    def __len__(self):\n",
    "        return len(self.buffer)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "f2f8a49e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-21T07:07:17.382889Z",
     "start_time": "2024-05-21T07:07:17.362851Z"
    }
   },
   "outputs": [],
   "source": [
    "class SAC:\n",
    "    \"\"\"Discrete-action Soft Actor-Critic with twin Q networks and a learnable temperature.\n",
    "\n",
    "    Args:\n",
    "        env: environment exposing ``observation_space.shape[0]`` and ``action_space.n``.\n",
    "        lr: learning rate shared by every Adam optimizer created here.\n",
    "        gamma: discount factor used in the TD target.\n",
    "        rho: soft-update rate for the target Q networks.\n",
    "        buffer_size: accepted for interface compatibility; not used inside this class.\n",
    "    \"\"\"\n",
    "    def __init__(self,env,lr=0.002,gamma=0.99,rho=0.01,buffer_size=10000):\n",
    "        self.env=env\n",
    "        self.gamma=gamma\n",
    "        self.rho=rho\n",
    "        # Target entropy for the temperature loss, stored as a plain float.\n",
    "        # NOTE(review): update() measures entropy with the natural log and entropy is never\n",
    "        # negative, so this negative log2-based target appears to always push alpha down - confirm.\n",
    "        self.target_entropy=float(-np.log2(env.action_space.n))\n",
    "        # Use the GPU when one is available, otherwise the CPU.\n",
    "        self.device=torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "        state_dim=env.observation_space.shape[0]\n",
    "        action_dim=env.action_space.n\n",
    "        self.actor=PolicyModel(state_dim,action_dim).to(self.device)\n",
    "        self.q1=QvalueModel(state_dim,action_dim).to(self.device)\n",
    "        self.q2=QvalueModel(state_dim,action_dim).to(self.device)\n",
    "        self.target_q1=QvalueModel(state_dim,action_dim).to(self.device)\n",
    "        self.target_q2=QvalueModel(state_dim,action_dim).to(self.device)\n",
    "        # Target networks start as exact copies of the online critics.\n",
    "        self.target_q1.load_state_dict(self.q1.state_dict())\n",
    "        self.target_q2.load_state_dict(self.q2.state_dict())\n",
    "        self.optimizer_actor=torch.optim.Adam(self.actor.parameters(),lr=lr)\n",
    "        self.optimizer_q1=torch.optim.Adam(self.q1.parameters(),lr=lr)\n",
    "        self.optimizer_q2=torch.optim.Adam(self.q2.parameters(),lr=lr)\n",
    "        # alpha is learned through its logarithm, which guarantees alpha=exp(log_alpha)>0\n",
    "        self.log_alpha=torch.tensor([0.0],device=self.device,requires_grad=True)\n",
    "        self.optimizer_log_alpha=torch.optim.Adam([self.log_alpha],lr=lr)\n",
    "    def choose_action(self,state):\n",
    "        \"\"\"Sample one action index from the current policy for a single state.\"\"\"\n",
    "        # wrap the state in a batch of one before feeding the actor\n",
    "        state=torch.FloatTensor(np.array([state])).to(self.device)\n",
    "        with torch.no_grad():\n",
    "            # the actor returns only the probability tensor\n",
    "            action_prob=self.actor(state)\n",
    "        c=torch.distributions.Categorical(action_prob)\n",
    "        action=c.sample()\n",
    "        return action.item()\n",
    "    def _soft_update(self,net,target_net):\n",
    "        \"\"\"Move ``target_net`` towards ``net``: target <- rho*net + (1-rho)*target.\"\"\"\n",
    "        for param,target_param in zip(net.parameters(),target_net.parameters()):\n",
    "            target_param.data.copy_(param.data*self.rho+target_param.data*(1-self.rho))\n",
    "    def update(self,buffer):\n",
    "        \"\"\"Run one gradient step on the critics, the actor and the temperature.\n",
    "\n",
    "        Args:\n",
    "            buffer: a ``(states, actions, rewards, next_states, dones)`` batch.\n",
    "        \"\"\"\n",
    "        states,actions,rewards,next_states,dones=buffer\n",
    "        states=torch.FloatTensor(np.array(states)).to(self.device)\n",
    "        # gather() needs int64 indices, so actions must not be converted to float\n",
    "        actions=torch.as_tensor(np.array(actions),dtype=torch.long,device=self.device).view(-1,1)\n",
    "        rewards=torch.FloatTensor(np.array(rewards)).view(-1,1).to(self.device)\n",
    "        next_states=torch.FloatTensor(np.array(next_states)).to(self.device)\n",
    "        dones=torch.FloatTensor(np.array(dones)).view(-1,1).to(self.device)\n",
    "        # The TD target is a constant for the critic update, so build it without autograd.\n",
    "        with torch.no_grad():\n",
    "            alpha=torch.exp(self.log_alpha)\n",
    "            # next-state action probabilities and their logs (1e-9 guards against log(0))\n",
    "            next_action_prob=self.actor(next_states)\n",
    "            log_next_prob=torch.log(next_action_prob+1e-9)\n",
    "            # target Q-values: element-wise minimum of the two target critics\n",
    "            target_q1=self.target_q1(next_states)\n",
    "            target_q2=self.target_q2(next_states)\n",
    "            target_q_min=torch.min(target_q1,target_q2)\n",
    "            # expectation over actions of (Q - alpha*log pi)\n",
    "            min_q_next_targets=next_action_prob*(target_q_min-alpha*log_next_prob)\n",
    "            min_q_next_targets=torch.sum(min_q_next_targets,dim=1,keepdim=True)\n",
    "            # TD target; (1-dones) removes the bootstrap term on the last step of an episode\n",
    "            td_target=rewards+(1-dones)*self.gamma*min_q_next_targets\n",
    "        # critic losses on the Q-values of the actions that were actually taken\n",
    "        q1=self.q1(states)\n",
    "        q2=self.q2(states)\n",
    "        q1_loss=F.mse_loss(q1.gather(1,actions),td_target)\n",
    "        q2_loss=F.mse_loss(q2.gather(1,actions),td_target)\n",
    "        # zero the gradients, backpropagate, step\n",
    "        self.optimizer_q1.zero_grad()\n",
    "        self.optimizer_q2.zero_grad()\n",
    "        q1_loss.backward()\n",
    "        q2_loss.backward()\n",
    "        self.optimizer_q1.step()\n",
    "        self.optimizer_q2.step()\n",
    "        # actor update\n",
    "        action_prob=self.actor(states)\n",
    "        log_prob=torch.log(action_prob+1e-9)\n",
    "        # the critics are fixed inputs here, so no gradient is tracked through them\n",
    "        with torch.no_grad():\n",
    "            min_q=torch.min(self.q1(states),self.q2(states))\n",
    "        inside_term=alpha*log_prob-min_q\n",
    "        actor_loss=torch.sum(action_prob*inside_term,dim=1,keepdim=True).mean()\n",
    "        self.optimizer_actor.zero_grad()\n",
    "        actor_loss.backward()\n",
    "        self.optimizer_actor.step()\n",
    "        # temperature loss: alpha*(entropy-target_entropy), with the entropy gap held constant\n",
    "        entropy=-torch.sum(action_prob*log_prob,dim=1,keepdim=True)\n",
    "        entropy_gap=(entropy-self.target_entropy).detach()\n",
    "        alpha_loss=(torch.exp(self.log_alpha)*entropy_gap).mean()\n",
    "        # log_alpha has its own optimizer; stepping the actor optimizer here would leave alpha unchanged\n",
    "        self.optimizer_log_alpha.zero_grad()\n",
    "        alpha_loss.backward()\n",
    "        self.optimizer_log_alpha.step()\n",
    "        # let the target critics slowly track the online critics\n",
    "        self._soft_update(self.q1,self.target_q1)\n",
    "        self._soft_update(self.q2,self.target_q2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "0e1ee0ec",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-21T07:07:51.261825Z",
     "start_time": "2024-05-21T07:07:51.219752Z"
    }
   },
   "outputs": [],
   "source": [
    "# Module-level device read by the EnsembleFC / EnsembleModel cells: CUDA when available, otherwise CPU.\n",
    "device=torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "c250d135",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-21T07:13:43.839034Z",
     "start_time": "2024-05-21T07:13:43.833516Z"
    }
   },
   "outputs": [],
   "source": [
    "class EnsembleFC(nn.Module):\n",
    "    \"\"\"Fully connected layer with a separate weight and bias for every ensemble member.\n",
    "\n",
    "    Args:\n",
    "        input_dim: number of input features.\n",
    "        output_dim: number of output features.\n",
    "        ensemble_size: number of ensemble members.\n",
    "    \"\"\"\n",
    "    def __init__(self,input_dim,output_dim,ensemble_size):\n",
    "        super(EnsembleFC,self).__init__()\n",
    "        self.input_dim=input_dim\n",
    "        self.output_dim=output_dim\n",
    "        # number of models in the ensemble\n",
    "        self.ensemble_size=ensemble_size\n",
    "        # weight: (ensemble_size,input_dim,output_dim)\n",
    "        # bias:   (ensemble_size,output_dim); forward() inserts the batch axis for broadcasting\n",
    "        self.weight=nn.Parameter(torch.empty(ensemble_size,input_dim,output_dim,device=device))\n",
    "        self.bias=nn.Parameter(torch.zeros(ensemble_size,output_dim,device=device))\n",
    "        # give the weights defined starting values instead of uninitialized memory\n",
    "        nn.init.normal_(self.weight,std=max(input_dim,1)**-0.5)\n",
    "    def forward(self,x):\n",
    "        \"\"\"Apply every member's layer to its own slice of ``x``.\n",
    "\n",
    "        ``x`` must be 3-D with shape (ensemble_size,batch,input_dim), as torch.bmm requires;\n",
    "        the result has shape (ensemble_size,batch,output_dim).\n",
    "        \"\"\"\n",
    "        # batched matrix product, one per ensemble member\n",
    "        w_time_x=torch.bmm(x,self.weight)\n",
    "        return torch.add(w_time_x,self.bias[:,None,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "8ccb8014",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-21T07:33:26.544172Z",
     "start_time": "2024-05-21T07:33:26.527128Z"
    }
   },
   "outputs": [],
   "source": [
    "class EnsembleModel(nn.Module):\n",
    "    \"\"\"Ensemble of four-layer networks that output a mean and a variance per target entry.\n",
    "\n",
    "    The input has ``state_size+action_size`` features and the output describes\n",
    "    ``state_size+reward_size`` entries (presumably next state and reward - confirm with the caller).\n",
    "\n",
    "    Args:\n",
    "        state_size: number of state features.\n",
    "        action_size: number of action features.\n",
    "        ensemble_size: number of ensemble members.\n",
    "        reward_size: number of reward entries in the output.\n",
    "        hidden_size: base width of the hidden layers.\n",
    "        lr: learning rate of the Adam optimizer.\n",
    "    \"\"\"\n",
    "    def __init__(self,state_size,action_size,ensemble_size,reward_size=1,hidden_size=128,lr=0.0001):\n",
    "        super(EnsembleModel,self).__init__()\n",
    "        self.hidden_size=hidden_size\n",
    "        self.input_dim=state_size+action_size\n",
    "        self.output_dim=state_size+reward_size\n",
    "        self.device=device\n",
    "        # Each layer's input width must equal the previous layer's output width.\n",
    "        self.nn1=EnsembleFC(self.input_dim,hidden_size,ensemble_size)\n",
    "        self.nn2=EnsembleFC(hidden_size,hidden_size*2,ensemble_size)\n",
    "        self.nn3=EnsembleFC(hidden_size*2,hidden_size,ensemble_size)\n",
    "        # the last layer is twice as wide: one half is the mean, the other the log-variance\n",
    "        self.nn4=EnsembleFC(hidden_size,self.output_dim*2,ensemble_size)\n",
    "        # learnable upper and lower bounds for the log-variance\n",
    "        # NOTE(review): the lower bound starts at 0 (zeros/2), i.e. variance >= 1 - confirm this is intended\n",
    "        self.max_logvar=nn.Parameter((torch.ones((1,self.output_dim)).float()/2).to(self.device),requires_grad=True)\n",
    "        self.min_logvar=nn.Parameter((torch.zeros((1,self.output_dim)).float()/2).to(self.device),requires_grad=True)\n",
    "        # Adam optimizer over every parameter registered above\n",
    "        self.optimizer=torch.optim.Adam(self.parameters(),lr=lr)\n",
    "\n",
    "        # weight initialization applied to every Linear / EnsembleFC submodule\n",
    "        def init_weights(m):\n",
    "            if isinstance(m,(nn.Linear,EnsembleFC)):\n",
    "                torch.nn.init.kaiming_normal_(m.weight,nonlinearity=\"relu\")\n",
    "                torch.nn.init.constant_(m.bias,0)\n",
    "        self.apply(init_weights)\n",
    "    def forward(self,x,ret_log_var=False):\n",
    "        \"\"\"Return ``(mean, var)``, or ``(mean, logvar)`` when ``ret_log_var`` is true.\n",
    "\n",
    "        ``x`` is passed through the EnsembleFC layers and the result is indexed as a 3-D tensor.\n",
    "        \"\"\"\n",
    "        x=torch.relu(self.nn1(x))\n",
    "        x=torch.relu(self.nn2(x))\n",
    "        x=torch.relu(self.nn3(x))\n",
    "        x=self.nn4(x)\n",
    "        # first half of the last axis: predicted mean\n",
    "        mean=x[:,:,:self.output_dim]\n",
    "        # second half: log-variance, squashed smoothly with softplus so that it stays\n",
    "        # below max_logvar and above min_logvar\n",
    "        logvar=self.max_logvar-F.softplus(self.max_logvar-x[:,:,self.output_dim:])\n",
    "        logvar=self.min_logvar+F.softplus(logvar-self.min_logvar)\n",
    "        # return the log-variance or the variance itself, depending on the flag\n",
    "        if ret_log_var:\n",
    "            return mean,logvar\n",
    "        else:\n",
    "            return mean,torch.exp(logvar)\n",
    "    def loss(self,mean,logvar,labels):\n",
    "        \"\"\"Return ``(total_loss, mean_loss)``.\n",
    "\n",
    "        ``mean_loss`` is the inverse-variance weighted squared error averaged over the last two\n",
    "        axes; ``total_loss`` adds the averaged log-variance and sums what remains.\n",
    "        \"\"\"\n",
    "        # inverse variance weights the squared error, so low-variance predictions count more\n",
    "        inv_var=torch.exp(-logvar)\n",
    "        mean_loss=torch.mean(torch.mean(torch.pow(mean-labels,2)*inv_var,dim=-1),dim=-1)\n",
    "        var_loss=torch.mean(torch.mean(logvar,dim=-1),dim=-1)\n",
    "        total_loss=torch.sum(mean_loss)+torch.sum(var_loss)\n",
    "        return total_loss,mean_loss\n",
    "    def train(self,loss=True):\n",
    "        \"\"\"Take one optimizer step on ``loss``.\n",
    "\n",
    "        This method shadows ``nn.Module.train``. A bool argument (or no argument) is forwarded\n",
    "        to ``nn.Module.train`` so that ``model.train()`` and ``model.eval()`` keep working.\n",
    "        \"\"\"\n",
    "        if isinstance(loss,bool):\n",
    "            return super(EnsembleModel,self).train(loss)\n",
    "        self.optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        self.optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff43391e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
